{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 初始化\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "import sys\n",
    "import os\n",
    "os.chdir('E:\\GitHub\\QA-abstract-and-reasoning')\n",
    "sys.path.append('E:\\GitHub\\QA-abstract-and-reasoning')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "from utils.config_gpu import config_gpu\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1 Physical GPUs, 1 Logical GPUs\n"
     ]
    }
   ],
   "source": [
    "config_gpu()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 参数准备\n",
    "## 模型参数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 110,
   "metadata": {},
   "outputs": [],
   "source": [
    "vocab_size = 32768\n",
    "embedding_dim = 300\n",
    "embedding_matrix = np.zeros((32768, 300))\n",
    "dec_units = enc_units = 256\n",
    "batch_size = 64"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 模型输入"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 构造GRU的输入 enc_input (batch_size, enc_len)\n",
    "enc_input = tf.zeros((batch_size, 200))\n",
    "# 插入嵌入层 enc_input2 (batch_size, enc_len, embedding_dim)\n",
    "embedding = tf.keras.layers.Embedding(vocab_size, \n",
    "                                      embedding_dim, \n",
    "                                      weights=[embedding_matrix],\n",
    "                                      trainable=False)\n",
    "enc_input2 = embedding(enc_input)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# GRU测试\n",
    "## 构造单\\双向GRU"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "gru = tf.keras.layers.GRU(enc_units,\n",
    "                       return_sequences=True,\n",
    "                       return_state=True,\n",
    "                       recurrent_initializer='glorot_uniform')\n",
    "bigru = tf.keras.layers.Bidirectional(gru, merge_mode=\"concat\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 单向GRU测试"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[<tf.Tensor: id=1612, shape=(64, 256), dtype=float32, numpy=\n",
       " array([[0., 0., 0., ..., 0., 0., 0.],\n",
       "        [0., 0., 0., ..., 0., 0., 0.],\n",
       "        [0., 0., 0., ..., 0., 0., 0.],\n",
       "        ...,\n",
       "        [0., 0., 0., ..., 0., 0., 0.],\n",
       "        [0., 0., 0., ..., 0., 0., 0.],\n",
       "        [0., 0., 0., ..., 0., 0., 0.]], dtype=float32)>]"
      ]
     },
     "execution_count": 39,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 获取初始状态\n",
    "initial_state = gru.get_initial_state(enc_input2)\n",
    "initial_state"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 进入gru层\n",
    "# output 所有gru单元的输出 (batch_size, enc_len, enc_units)\n",
    "# state 最后一个gru单元的输出 (batch_size, enc_units)\n",
    "output, state = gru(enc_input2, initial_state=initial_state)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 64,
   "metadata": {},
   "outputs": [],
   "source": [
    "output, state = gru(enc_input2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 双向GRU测试"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# initial_state 传入的是个列表 因为bigru要初始化前向和后向的gru单元\n",
    "# 所以initial_state 的长度是2 [state1, state2]\n",
    "# (batch_size, enc_len, enc_units*2)\n",
    "# forward_state, backward_state (batch_size, enc_units*2)\n",
    "output, forward_state, backward_state = bigru(enc_input2, initial_state=initial_state*2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# LSTM测试"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "lstm = tf.keras.layers.LSTM(enc_units,\n",
    "                       return_sequences=True,\n",
    "                       return_state=True,\n",
    "                       recurrent_initializer='glorot_uniform')\n",
    "bilstm = tf.keras.layers.Bidirectional(lstm, merge_mode=\"concat\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![](https://img-blog.csdn.net/20180712120310214?watermark/2/text/aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L0FteV9tbQ==/font/5a6L5L2T/fontsize/400/fill/I0JBQkFCMA==/dissolve/70)\n",
    "\n",
    "![](http://colah.github.io/posts/2015-08-Understanding-LSTMs/img/LSTM3-focus-o.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 单向\n",
    "与gru不同的是，lstm的输出会多一个"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {},
   "outputs": [],
   "source": [
    "enc_input = tf.zeros((batch_size, 200))\n",
    "enc_input2 = embedding(enc_input)\n",
    "# initial_state [state1, state2] state (batch_size, enc_units)\n",
    "initial_state = lstm.get_initial_state(enc_input2)\n",
    "# enc_output (batch_size, enc_len, enc_units)\n",
    "# c_t, h_t (batch_size, enc_units)\n",
    "enc_output, enc_hidden, c_t = lstm(enc_input2, initial_state=initial_state)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 试着搭建LSTM的seq2seq模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "metadata": {},
   "outputs": [],
   "source": [
    "# LSTM RNN网络范例\n",
    "class RNN(tf.keras.Model):\n",
    "    def __init__(self, num_chars, batch_size, seq_length):\n",
    "        super().__init__()\n",
    "        self.num_chars = num_chars\n",
    "        self.seq_length = seq_length\n",
    "        self.batch_size = batch_size\n",
    "        self.cell = tf.keras.layers.LSTMCell(units=256)\n",
    "        self.dense = tf.keras.layers.Dense(units=self.num_chars)\n",
    "\n",
    "    def call(self, inputs, from_logits=False):\n",
    "        inputs = tf.one_hot(inputs, depth=self.num_chars)       # [batch_size, seq_length, num_chars]\n",
    "        state = self.cell.get_initial_state(batch_size=self.batch_size, dtype=tf.float32)\n",
    "        for t in range(self.seq_length):\n",
    "            output, state = self.cell(inputs[:, t, :], state)\n",
    "        logits = self.dense(output)\n",
    "        if from_logits:\n",
    "            return logits\n",
    "        else:\n",
    "            return tf.nn.softmax(logits)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 145,
   "metadata": {},
   "outputs": [],
   "source": [
    "class LSTM_Encoder(tf.keras.Model):\n",
    "    def __init__(self, vocab_size, embedding_dim, embedding_matrix, enc_units, batch_size):\n",
    "        super(LSTM_Encoder, self).__init__()\n",
    "        self.batch_size = batch_size\n",
    "        self.enc_units = enc_units\n",
    "        self.use_bi_lstm = False\n",
    "        # 双向\n",
    "        if self.use_bi_lstm:\n",
    "            self.enc_units = self.enc_units // 2\n",
    "\n",
    "        self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim, weights=[embedding_matrix],\n",
    "                                                   trainable=False)\n",
    "        self.lstm = tf.keras.layers.LSTM(self.enc_units,\n",
    "                                   return_sequences=True,\n",
    "                                   return_state=True,\n",
    "                                   recurrent_initializer='glorot_uniform')\n",
    "        \n",
    "\n",
    "        self.bi_lstm = tf.keras.layers.Bidirectional(self.lstm)\n",
    "\n",
    "    def call(self, enc_input):\n",
    "        # (batch_size, enc_len, embedding_dim)\n",
    "        enc_input_embedded = self.embedding(enc_input)\n",
    "        initial_state = self.lstm.get_initial_state(enc_input_embedded)\n",
    "\n",
    "        if self.use_bi_lstm:\n",
    "            # 是否使用双向GRU\n",
    "            output, forward_state, backward_state = self.bi_lstm(enc_input_embedded, initial_state=initial_state * 2)\n",
    "            enc_hidden = tf.keras.layers.concatenate([forward_state, backward_state], axis=-1)\n",
    "\n",
    "        else:\n",
    "            # 单向GRU\n",
    "            output, enc_hidden, c_t  = self.lstm(enc_input_embedded, initial_state=initial_state)\n",
    "\n",
    "        return output, enc_hidden, c_t"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 146,
   "metadata": {},
   "outputs": [],
   "source": [
    "class BahdanauAttention(tf.keras.layers.Layer):\n",
    "    def __init__(self, units):\n",
    "        super(BahdanauAttention, self).__init__()\n",
    "        self.W1 = tf.keras.layers.Dense(units)\n",
    "        self.W2 = tf.keras.layers.Dense(units)\n",
    "        self.V = tf.keras.layers.Dense(1)\n",
    "\n",
    "    def call(self, dec_hidden, enc_output):\n",
    "        # dec_hidden shape == (batch_size, dec_units)\n",
    "        # enc_output (batch_size, enc_len, enc_units)\n",
    "\n",
    "        # hidden_with_time_axis shape == (batch_size, 1, dec_units)\n",
    "        # we are doing this to perform addition to calculate the score\n",
    "        hidden_with_time_axis = tf.expand_dims(dec_hidden, 1)\n",
    "\n",
    "        # we get 1 at the last axis because we are applying score to self.V\n",
    "        # self.V 括号内的维度为 (batch_size, enc_len, attn_units)\n",
    "        # score (batch_size, enc_len, 1)\n",
    "        score = self.V(tf.nn.tanh(\n",
    "            self.W1(enc_output) + self.W2(hidden_with_time_axis)))\n",
    "\n",
    "        # attention_weights (batch_size, enc_len, 1)\n",
    "        attention_weights = tf.nn.softmax(score, axis=1)\n",
    "\n",
    "        # # 使用注意力权重*编码器输出作为返回值，将来会作为解码器的输入\n",
    "        # enc_output (batch_size, enc_len, enc_units)\n",
    "        # attention_weights (batch_size, enc_len, 1)\n",
    "        context_vector = attention_weights * enc_output\n",
    "\n",
    "        # context_vector (batch_size, enc_units)\n",
    "        context_vector = tf.reduce_sum(context_vector, axis=1)\n",
    "        return context_vector, attention_weights"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 147,
   "metadata": {},
   "outputs": [],
   "source": [
    "class LSTM_Decoder(tf.keras.Model):\n",
    "    def __init__(self, vocab_size, embedding_dim, embedding_matrix, dec_units, batch_size):\n",
    "        super(LSTM_Decoder, self).__init__()\n",
    "        self.batch_sz = batch_size\n",
    "        self.dec_units = dec_units\n",
    "        self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim, weights=[embedding_matrix],\n",
    "                                                   trainable=False)\n",
    "        \n",
    "        self.lstm = tf.keras.layers.LSTMCell(units=self.dec_units,\n",
    "                                            recurrent_initializer='glorot_uniform')\n",
    "        self.attention = BahdanauAttention(self.dec_units)\n",
    "        self.fc = tf.keras.layers.Dense(vocab_size)\n",
    "\n",
    "    def call(self, dec_input, prev_dec_state, enc_output):\n",
    "        # dec_input 一个单词 (batch_size, )\n",
    "        \n",
    "        # prev_dec_state LSTM 的state有两个，用list装起来\n",
    "        # prev_dec_state[0] h_t (batch_size, units)\n",
    "        # prev_dec_state[1] c_t (batch_size, units)\n",
    "        \n",
    "        # enc_output 用来计算注意力\n",
    "        # dec_input (batch_size, embedding_dim)\n",
    "        dec_input = self.embedding(dec_input)\n",
    "        # 用h_t计算注意力\n",
    "        \n",
    "        # context_vector (batch_size, units)\n",
    "        # attention_weights (batch_size, enc_len, 1)\n",
    "        context_vector, attention_weights = self.attention(prev_dec_state[0], enc_output)\n",
    "        # tf.squeeze(attention_weights)\n",
    "        \n",
    "        #dec_input (batch_size, units+embedding_dim)\n",
    "        dec_input = tf.concat([context_vector, dec_input], axis=-1)\n",
    "\n",
    "        # dec_output (batch_size, units)\n",
    "        # `dec_state` same as `prev_dec_state`\n",
    "        # PS: dec_output.shape == dec_state[0].shape\n",
    "        dec_output, dec_state = self.lstm(dec_input, prev_state)\n",
    "        # 来自源代码:\n",
    "        # tf.keras.layers.LSTMCell 的返回值如下\n",
    "        # return h, [h, c]\n",
    "        \n",
    "        # pred (batch_size, vocab_size)\n",
    "        pred = self.fc(dec_output)\n",
    "        return pred, dec_state, attention_weights"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 148,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 创建一个LSTM编码器\n",
    "lstm_enc = LSTM_Encoder(vocab_size, embedding_dim, embedding_matrix, enc_units, batch_size)\n",
    "\n",
    "# 获得encoder的输出\n",
    "enc_output, enc_hidden, c_t = lstm_enc(enc_input)\n",
    "# 构造decoder的初始输入\n",
    "dec_input = tf.constant([32766] * batch_size)\n",
    "# 初始化LSTM decoder\n",
    "lstm_dec = LSTM_Decoder(vocab_size, embedding_dim, embedding_matrix, dec_units, batch_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 152,
   "metadata": {},
   "outputs": [],
   "source": [
    "prev_state = [enc_hidden, c_t]\n",
    "# 计算decoder的输出\n",
    "pred, dec_state, attention_weights = lstm_dec(dec_input, prev_state, enc_output)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:tf2.0]",
   "language": "python",
   "name": "conda-env-tf2.0-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
